package com.elong.pb.newdda.client.jdbc;

import com.elong.pb.newdda.client.datasource.DataSourceContainer;
import com.elong.pb.newdda.client.datasource.MasterSlaveDataSource;
import com.elong.pb.newdda.client.jdbc.adapter.AbstractConnectionAdapter;
import com.elong.pb.newdda.client.router.SqlRouterEngine;
import com.elong.pb.newdda.client.router.result.router.SqlStatementType;
import com.elong.pb.newdda.client.router.rule.ShardingRule;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;

/**
 * Logical JDBC connection that spans the sharded data sources described by a
 * {@link ShardingRule}.
 *
 * <p>Statements created here are {@link ShardingStatement} /
 * {@link ShardingPreparedStatement} instances bound to this connection and to a
 * {@link SqlRouterEngine} built from the same rule. Physical connections are
 * opened on demand by {@link #getConnection(String, SqlStatementType)} and cached
 * per data source name for later reuse by this object.
 *
 * <p>The cache is a plain {@link HashMap}; this class adds no synchronization of
 * its own.
 *
 * User: zhangyong
 * Date: 2016/3/31
 * Time: 18:53
 */
public class ShardingConnection extends AbstractConnectionAdapter {

    private static final Logger logger = LoggerFactory.getLogger(ShardingConnection.class);

    private final ShardingRule shardingRule;

    private final SqlRouterEngine sqlRouterEngine;

    /** Physical connections opened so far, keyed by data source name. */
    private final Map<String, Connection> connectionMap = new HashMap<String, Connection>();

    /**
     * Creates a sharding connection and its SQL router engine from the given rule.
     *
     * @param shardingRule rule used to build the router engine and to look up data sources
     */
    public ShardingConnection(ShardingRule shardingRule) {
        this.shardingRule = shardingRule;
        this.sqlRouterEngine = new SqlRouterEngine(shardingRule);
    }

    /**
     * Returns the physical connections opened so far.
     *
     * @return live view of the cached connections (backed by the internal map)
     */
    @Override
    protected Collection<Connection> getConnections() {
        return connectionMap.values();
    }

    @Override
    public Statement createStatement() throws SQLException {
        return new ShardingStatement(this, sqlRouterEngine);
    }

    @Override
    public Statement createStatement(int resultSetType, int resultSetConcurrency) throws SQLException {
        return new ShardingStatement(this, sqlRouterEngine, resultSetType, resultSetConcurrency);
    }

    @Override
    public Statement createStatement(int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException {
        // NOTE(review): resultSetHoldability is not forwarded, unlike the matching
        // prepareStatement overload below. Confirm whether ShardingStatement offers a
        // holdability-aware constructor and pass the value through if it does.
        return new ShardingStatement(this, sqlRouterEngine, resultSetType, resultSetConcurrency);
    }

    @Override
    public PreparedStatement prepareStatement(String sql) throws SQLException {
        return new ShardingPreparedStatement(this, sqlRouterEngine, sql);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency) throws SQLException {
        return new ShardingPreparedStatement(this, sqlRouterEngine, sql, resultSetType, resultSetConcurrency);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) throws SQLException {
        return new ShardingPreparedStatement(this, sqlRouterEngine, sql, resultSetType, resultSetConcurrency, resultSetHoldability);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int autoGeneratedKeys) throws SQLException {
        return new ShardingPreparedStatement(this, sqlRouterEngine, sql, autoGeneratedKeys);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, int[] columnIndexes) throws SQLException {
        return new ShardingPreparedStatement(this, sqlRouterEngine, sql, columnIndexes);
    }

    @Override
    public PreparedStatement prepareStatement(String sql, String[] columnNames) throws SQLException {
        return new ShardingPreparedStatement(this, sqlRouterEngine, sql, columnNames);
    }

    //======================================================== Obtain a connection by data source name and SQL type ========================================================

    /**
     * Returns the physical connection for the named data source, opening and caching
     * it on first use.
     *
     * @param dataSourceName   name of the data source in the rule's data source container
     * @param sqlStatementType type of the SQL statement; currently not used for selection
     * @return the cached or newly opened connection
     * @throws SQLException if the data source is not configured or the connection cannot be opened
     */
    public Connection getConnection(final String dataSourceName, final SqlStatementType sqlStatementType) throws SQLException {
        return getConnectionInternal(dataSourceName, sqlStatementType);
    }

    /**
     * Looks up the cached connection for {@code dataSourceName}, or opens one on the
     * master data source and caches it.
     *
     * @throws SQLException if no data source (or no master) is configured under that
     *                      name, or if opening the connection fails
     */
    private Connection getConnectionInternal(final String dataSourceName, final SqlStatementType sqlStatementType) throws SQLException {
        Connection cached = connectionMap.get(dataSourceName);
        if (cached != null) {
            return cached;
        }
        DataSourceContainer dataSourceContainer = shardingRule.getDataSourceContainer();
        MasterSlaveDataSource masterSlaveDataSource = dataSourceContainer.getContainer().get(dataSourceName);
        if (masterSlaveDataSource == null) {
            // Fail with a diagnosable error instead of a NullPointerException below.
            throw new SQLException("No data source configured for name: " + dataSourceName);
        }
        //TODO first version always uses the master database
        DataSource dataSource = masterSlaveDataSource.getMasterDataSource();
        if (dataSource == null) {
            throw new SQLException("No master data source configured for name: " + dataSourceName);
        }
        Connection connection = dataSource.getConnection();
        connectionMap.put(dataSourceName, connection);
        return connection;
    }

}
